[Feat] Adds LongCat-AudioDiT pipeline #13390

RuixiangMa wants to merge 14 commits into huggingface:main
Conversation
Signed-off-by: Lancer <maruixiang6688@gmail.com>
Force-pushed from 9c4613f to d2a2621
src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py — outdated review threads, resolved.
```python
def _pixel_shuffle_1d(hidden_states: torch.Tensor, factor: int) -> torch.Tensor:
```

Similarly, I think we should inline `_pixel_shuffle_1d` in `UpsampleShortcut`, following #13390 (comment).
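For context, a 1D pixel shuffle rearranges channel groups into the length dimension. A minimal standalone sketch, assuming the conventional channels-to-length layout (the actual helper in this PR may differ in detail):

```python
import torch

def pixel_shuffle_1d(hidden_states: torch.Tensor, factor: int) -> torch.Tensor:
    # Rearrange (batch, channels, length) -> (batch, channels // factor, length * factor)
    # by interleaving `factor` consecutive channels along the length axis.
    batch, channels, length = hidden_states.shape
    hidden_states = hidden_states.view(batch, channels // factor, factor, length)
    hidden_states = hidden_states.permute(0, 1, 3, 2)  # (batch, C // factor, length, factor)
    return hidden_states.reshape(batch, channels // factor, length * factor)

x = torch.arange(12.0).reshape(1, 4, 3)
y = pixel_shuffle_1d(x, 2)
print(tuple(y.shape))      # (1, 2, 6)
print(y[0, 0].tolist())    # [0.0, 3.0, 1.0, 4.0, 2.0, 5.0]
```

Inlining this rearrangement in `UpsampleShortcut`, as suggested, removes the free function without changing behavior.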
src/diffusers/models/autoencoders/autoencoder_longcat_audio_dit.py — several more outdated review threads, all resolved.
src/diffusers/models/transformers/transformer_longcat_audio_dit.py — several outdated review threads, all resolved.
```python
self.time_embed = AudioDiTTimestepEmbedding(dim)
self.input_embed = AudioDiTEmbedder(latent_dim, dim)
self.text_embed = AudioDiTEmbedder(dit_text_dim, dim)
self.rotary_embed = AudioDiTRotaryEmbedding(dim_head, 2048, base=100000.0)
self.blocks = nn.ModuleList(
```

See #13390 (comment).
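For context, `AudioDiTRotaryEmbedding(dim_head, 2048, base=100000.0)` suggests a standard rotary position embedding with a 2048-position table and base 100000. A hedged sketch of the usual frequency-table construction (the actual class in this PR may differ):

```python
import torch

def rotary_frequencies(dim_head: int, max_positions: int, base: float = 100000.0):
    # Standard RoPE table: one inverse frequency per even channel index,
    # one rotation angle per (position, frequency) pair, cached as cos/sin.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim_head, 2).float() / dim_head))
    angles = torch.outer(torch.arange(max_positions).float(), inv_freq)
    return angles.cos(), angles.sin()  # each (max_positions, dim_head // 2)

cos, sin = rotary_frequencies(dim_head=8, max_positions=2048, base=100000.0)
print(tuple(cos.shape))  # (2048, 4)
```

At position 0 every angle is zero, so `cos[0]` is all ones and `sin[0]` all zeros, which is a quick sanity check on any RoPE implementation.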
src/diffusers/models/transformers/transformer_longcat_audio_dit.py — outdated review thread, resolved.
```python
batch_size = hidden_states.shape[0]
if timestep.ndim == 0:
    timestep = timestep.repeat(batch_size)
timestep_embed = self.time_embed(timestep)
text_mask = encoder_attention_mask.bool()
encoder_hidden_states = self.text_embed(encoder_hidden_states, text_mask)
```

Can you also refactor `forward` here so that it is better organized, following #13390 (comment)? See for example the `QwenImageTransformer2DModel.forward` method.
Reorganized parts of forward incrementally; kept the current structure otherwise to avoid unnecessary behavioral churn.
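To illustrate the organization being asked for (compute all embeddings up front, then run the blocks, then project out), here is a hypothetical miniature; it is not the real LongCat model, and every layer name, shape, and stand-in computation below is made up:

```python
import torch
from torch import nn

class TinyAudioDiT(nn.Module):
    # Toy model showing the suggested forward structure, not the PR's actual code.
    def __init__(self, latent_dim=8, dim=16, text_dim=12, num_blocks=2):
        super().__init__()
        self.input_embed = nn.Linear(latent_dim, dim)
        self.text_embed = nn.Linear(text_dim, dim)
        self.time_embed = nn.Linear(1, dim)
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_blocks))
        self.proj_out = nn.Linear(dim, latent_dim)

    def forward(self, hidden_states, timestep, encoder_hidden_states, encoder_attention_mask):
        batch_size = hidden_states.shape[0]
        # 1. Broadcast scalar timesteps and prepare all embeddings before the blocks.
        if timestep.ndim == 0:
            timestep = timestep.repeat(batch_size)
        temb = self.time_embed(timestep[:, None].float())
        hidden_states = self.input_embed(hidden_states)
        text_mask = encoder_attention_mask.bool()
        encoder_hidden_states = self.text_embed(encoder_hidden_states) * text_mask[..., None]
        # 2. Transformer blocks (stand-in linears here) consume the prepared tensors.
        hidden_states = hidden_states + temb[:, None, :]
        for block in self.blocks:
            hidden_states = block(hidden_states) + encoder_hidden_states.mean(1, keepdim=True)
        # 3. Final projection back to the latent space.
        return self.proj_out(hidden_states)

model = TinyAudioDiT()
out = model(torch.randn(2, 5, 8), torch.tensor(3), torch.randn(2, 4, 12), torch.ones(2, 4))
print(tuple(out.shape))  # (2, 5, 8)
```

Grouping the preparation, the block loop, and the output projection into labeled sections is what makes the Qwen-style forward easier to follow.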
src/diffusers/models/transformers/transformer_longcat_audio_dit.py — outdated review threads, resolved.
```python
class TestLongCatAudioDiTTransformerMemory(LongCatAudioDiTTransformerTesterConfig, MemoryTesterMixin):
    def test_layerwise_casting_memory(self):
        pytest.skip("LongCatAudioDiTTransformer does not support standard layerwise casting memory tests yet.")

    def test_layerwise_casting_training(self):
        pytest.skip("LongCatAudioDiTTransformer does not support standard layerwise casting training tests yet.")

    def test_group_offloading_with_layerwise_casting(self, *args, **kwargs):
        pytest.skip(
            "LongCatAudioDiTTransformer does not support combined group offloading and layerwise casting tests yet."
        )
```

Suggested change:

```python
class TestLongCatAudioDiTTransformerMemory(LongCatAudioDiTTransformerTesterConfig, MemoryTesterMixin):
    pass
```

Layerwise casting should work if #13390 (comment) is applied.

I removed the layerwise-casting training and combined group-offloading/layerwise-casting skips after updating the dtype handling. I kept test_layerwise_casting_memory skipped because the tiny transformer config does not provide stable peak-memory behavior for that assertion.
tests/models/transformers/test_models_transformer_longcat_audio_dit.py — outdated review thread, resolved.
dg845 left a comment:
Thanks for your continued work on this! Left some suggestions that should help LongCatAudioDiTPipeline support model offloading, layerwise casting, etc.
@bot /style

Style bot fixed some files and pushed the changes.
```python
@classmethod
@validate_hf_hub_args
def from_pretrained(
```
Can you add a conversion script? Our pipelines should not define a `from_pretrained` method.

Added it and tested.
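The core of such a conversion script is usually a key remap from the original checkpoint's naming to diffusers naming, after which the converted weights are loaded into the diffusers modules and saved with `save_pretrained`. A sketch of the remapping step, with purely illustrative prefixes (the real LongCat key names are not shown in this thread):

```python
# Hypothetical prefix map; the actual LongCat checkpoint keys will differ.
RENAME_MAP = {
    "denoiser.": "transformer.",
    "vae_model.": "vae.",
}

def convert_state_dict(state_dict: dict) -> dict:
    # Rewrite each key's prefix according to RENAME_MAP, leaving values untouched.
    converted = {}
    for key, value in state_dict.items():
        new_key = key
        for old_prefix, new_prefix in RENAME_MAP.items():
            if new_key.startswith(old_prefix):
                new_key = new_prefix + new_key[len(old_prefix):]
        converted[new_key] = value
    return converted

print(convert_state_dict({"denoiser.blocks.0.weight": 1.0}))
# {'transformer.blocks.0.weight': 1.0}
```

Keeping the map in one dictionary makes it easy to audit against the original checkpoint and extend when layer names change.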
@claude can you help with a review here?

I'll analyze this and get back to you.
```python
timesteps = self.scheduler.timesteps
self._num_timesteps = len(timesteps)

for i, t in enumerate(timesteps):
```

Can you add support for a progress bar here? For example, here is how Flux 2 implements a progress bar with `self.progress_bar`: src/diffusers/pipelines/flux2/pipeline_flux2.py, lines 955 to 956 (at 5063aa5). This will make it easier to track progress during inference.

Done:
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:03<00:00, 13.68it/s]
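The diffusers pattern is to wrap the denoising loop in `with self.progress_bar(total=num_inference_steps) as progress_bar:` and call `progress_bar.update()` each step; the context manager is inherited from `DiffusionPipeline` and is tqdm-backed. A standalone sketch that mimics the shape of that pattern with a stand-in pipeline class (everything here is illustrative, not the real pipeline):

```python
from contextlib import contextmanager

class MiniPipeline:
    # Hypothetical stand-in: real pipelines get a tqdm-backed `progress_bar`
    # context manager from diffusers' DiffusionPipeline base class.
    @contextmanager
    def progress_bar(self, total):
        state = {"done": 0, "total": total}

        class _Bar:
            def update(self, n=1):
                state["done"] += n
                print(f"\rstep {state['done']}/{state['total']}", end="")

        yield _Bar()
        print()

    def denoise(self, timesteps):
        completed = 0
        with self.progress_bar(total=len(timesteps)) as bar:
            for i, t in enumerate(timesteps):
                # model forward + scheduler.step(...) would go here
                completed += 1
                bar.update()
        return completed

print(MiniPipeline().denoise(list(range(50))))  # prints progress, then 50
```

Driving the bar from inside the loop keeps the step count in lockstep with the scheduler, matching what the tqdm trace above shows.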

What does this PR do?
Adds LongCat-AudioDiT model support to diffusers.
Although LongCat-AudioDiT can be used for TTS-like generation, it is fundamentally a diffusion-based audio generation model (text conditioning, iterative latent denoising, and VAE decoding) rather than a conventional autoregressive TTS model, so I think it fits naturally into diffusers.
Test
Result
longcat.wav
Before submitting
Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.